{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing your own model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we show how to implement your own model and test it on a dataset. \n",
    "\n",
    "This example loads the MUTAG dataset, applies a hypergraph lifting to turn its graphs into hypergraphs, and defines a model that operates on them.\n",
    "\n",
    "We train the model on the training split, monitor it on the validation split, and finally evaluate it on the test split."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <font color='289C4E'>Table of contents</font><a class='anchor' id='top'></a>\n",
    "&emsp;[1. Imports](##sec1)\n",
    "\n",
    "&emsp;[2. Configurations and utilities](##sec2)\n",
    "\n",
    "&emsp;[3. Loading the data](##sec3)\n",
    "\n",
    "&emsp;[4. Backbone definition](##sec4)\n",
    "\n",
    "&emsp;[5. Model initialization](##sec5)\n",
    "\n",
    "&emsp;[6. Training](##sec6)\n",
    "\n",
    "&emsp;[7. Testing the model](##sec7)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Imports <a class=\"anchor\" id=\"sec1\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import lightning as pl\n",
    "# Hydra related imports\n",
    "from omegaconf import OmegaConf\n",
    "# Data related imports\n",
    "from topobench.data.loaders.graph import TUDatasetLoader\n",
    "from topobench.dataloader.dataloader import TBDataloader\n",
    "from topobench.data.preprocessor import PreProcessor\n",
    "# Model related imports\n",
    "from topobench.model.model import TBModel\n",
    "from topomodelx.nn.simplicial.scn2 import SCN2\n",
    "from topobench.nn.wrappers.simplicial import SCNWrapper\n",
    "from topobench.nn.encoders import AllCellFeatureEncoder\n",
    "from topobench.nn.readouts import PropagateSignalDown\n",
    "# Optimization related imports\n",
    "from topobench.loss.loss import TBLoss\n",
    "from topobench.optimizer import TBOptimizer\n",
    "from topobench.evaluator.evaluator import TBEvaluator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Configurations and utilities <a class=\"anchor\" id=\"sec2\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Configurations can be specified in YAML files or, as in this example, directly in code as dictionaries that are then wrapped with `OmegaConf.create`."
   ]
  },
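  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the YAML route can be sketched as follows. This is a minimal illustration, not part of the pipeline: `yaml_text` and `loader_config_from_yaml` are hypothetical names, and the YAML simply mirrors the loader settings used in this tutorial. `OmegaConf.create` parses a YAML string; for a file on disk, `OmegaConf.load` plays the same role."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "\n",
    "# Hypothetical YAML snippet mirroring the loader settings of this tutorial\n",
    "yaml_text = \"\"\"\n",
    "data_domain: graph\n",
    "data_type: TUDataset\n",
    "data_name: MUTAG\n",
    "data_dir: ./data/MUTAG/\n",
    "\"\"\"\n",
    "\n",
    "# OmegaConf.create accepts a YAML string and returns a DictConfig\n",
    "loader_config_from_yaml = OmegaConf.create(yaml_text)\n",
    "print(loader_config_from_yaml.data_name)"
   ]
  },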
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader_config = {\n",
    "    \"data_domain\": \"graph\",\n",
    "    \"data_type\": \"TUDataset\",\n",
    "    \"data_name\": \"MUTAG\",\n",
    "    \"data_dir\": \"./data/MUTAG/\",\n",
    "}\n",
    "\n",
    "transform_config = {\n",
    "    \"khop_lifting\": {\n",
    "        \"transform_type\": \"lifting\",\n",
    "        \"transform_name\": \"HypergraphKHopLifting\",\n",
    "        \"k_value\": 1,\n",
    "    }\n",
    "}\n",
    "\n",
    "split_config = {\n",
    "    \"learning_setting\": \"inductive\",\n",
    "    \"split_type\": \"random\",\n",
    "    \"data_seed\": 0,\n",
    "    \"data_split_dir\": \"./data/MUTAG/splits/\",\n",
    "    \"train_prop\": 0.5,\n",
    "}\n",
    "\n",
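    "in_channels = 7  # dimension of the MUTAG node features\n",
    "out_channels = 2  # number of classes (binary graph classification)\n",
    "dim_hidden = 16  # hidden width shared by encoder, backbone and readout\n",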
    "\n",
    "readout_config = {\n",
    "    \"readout_name\": \"PropagateSignalDown\",\n",
    "    \"num_cell_dimensions\": 1,\n",
    "    \"hidden_dim\": dim_hidden,\n",
    "    \"out_channels\": out_channels,\n",
    "    \"task_level\": \"graph\",\n",
    "    \"pooling_type\": \"sum\",\n",
    "}\n",
    "\n",
    "loss_config = {\n",
    "    \"dataset_loss\": {\n",
    "        \"task\": \"classification\",\n",
    "        \"loss_type\": \"cross_entropy\",\n",
    "    }\n",
    "}\n",
    "\n",
    "evaluator_config = {\n",
    "    \"task\": \"classification\",\n",
    "    \"num_classes\": out_channels,\n",
    "    \"metrics\": [\"accuracy\", \"precision\", \"recall\"],\n",
    "}\n",
    "\n",
    "optimizer_config = {\n",
    "    \"optimizer_id\": \"Adam\",\n",
    "    \"parameters\": {\"lr\": 0.001, \"weight_decay\": 0.0005},\n",
    "}\n",
    "\n",
    "loader_config = OmegaConf.create(loader_config)\n",
    "transform_config = OmegaConf.create(transform_config)\n",
    "split_config = OmegaConf.create(split_config)\n",
    "readout_config = OmegaConf.create(readout_config)\n",
    "loss_config = OmegaConf.create(loss_config)\n",
    "evaluator_config = OmegaConf.create(evaluator_config)\n",
    "optimizer_config = OmegaConf.create(optimizer_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Loading the data <a class=\"anchor\" id=\"sec3\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MUTAG is a graph dataset, so we apply the k-hop lifting (here with $k=1$, as set in `transform_config`) to turn each graph into a hypergraph.\n",
    "\n",
    "We invite you to check out the README of the [repository](https://github.com/pyt-team/TopoBenchX) to learn more about the various liftings offered."
   ]
  },
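  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give an intuition for what the lifting produces, here is a minimal sketch of a k-hop incidence matrix on a toy path graph. It is written in plain `torch` and is not TopoBench's implementation: the helper `khop_incidence` and the toy `edge_index` are hypothetical, and the sketch assumes an undirected graph in which every node spawns one hyperedge containing itself and all nodes within $k$ hops."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def khop_incidence(edge_index, num_nodes, k=1):\n",
    "    \"\"\"Illustrative helper: dense k-hop incidence matrix of an undirected graph.\"\"\"\n",
    "    # Symmetric adjacency matrix with self-loops\n",
    "    adj = torch.eye(num_nodes)\n",
    "    adj[edge_index[0], edge_index[1]] = 1.0\n",
    "    adj[edge_index[1], edge_index[0]] = 1.0\n",
    "    # reach[i, j] = 1 if node j lies within the current number of hops of node i\n",
    "    reach = torch.eye(num_nodes)\n",
    "    for _ in range(k):\n",
    "        reach = ((reach @ adj) > 0).float()\n",
    "    return reach\n",
    "\n",
    "\n",
    "# Toy path graph 0 - 1 - 2, stored as a [2, num_edges] tensor\n",
    "edge_index = torch.tensor([[0, 1], [1, 2]])\n",
    "incidence = khop_incidence(edge_index, num_nodes=3, k=1)\n",
    "print(incidence)"
   ]
  },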
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transform parameters are the same, using existing data_dir: data/MUTAG/MUTAG/khop_lifting/1116229528\n"
     ]
    }
   ],
   "source": [
    "graph_loader = TUDatasetLoader(loader_config)\n",
    "\n",
    "dataset, dataset_dir = graph_loader.load()\n",
    "\n",
    "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n",
    "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n",
    "datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Backbone definition <a class=\"anchor\" id=\"sec4\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
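    "To implement a new model we only need to define its `forward` method.\n",
    "\n",
    "For a hypergraph with $n$ nodes and $m$ hyperedges, this model computes the hyperedge features as $X_1 = B_1 \\cdot X_0$, where $X_0$ holds one row of features per node and $B_1 \\in \\mathbb{R}^{n \\times m}$ is the incidence matrix, with $(B_1)_{ij}=1$ if node $i$ belongs to hyperedge $j$ and $0$ otherwise. The product is well defined here because the k-hop lifting creates one hyperedge per node, so $m = n$; for a general lifting one would use $B_1^{\\top} \\cdot X_0$ instead.\n",
    "\n",
    "The outputs are then computed as $X^{'}_0=\\text{ReLU}(X_0 W_0^{\\top} + b_0)$ and $X^{'}_1=\\text{ReLU}(X_1 W_1^{\\top} + b_1)$, that is, with two linear layers, each followed by a ReLU activation. Here $b_0$ and $b_1$ are the bias vectors of the two layers."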
   ]
  },
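  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check of the aggregation step, the sketch below applies $X_1 = B_1 \\cdot X_0$ to a hand-written example. The incidence matrix and the features are made up for illustration (a path graph on three nodes with its 1-hop hyperedges) and are not taken from MUTAG. Each row of `x_1` is the sum of the features of the nodes in the corresponding hyperedge."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical incidence matrix: 3 nodes, 3 hyperedges, stored as a sparse tensor\n",
    "toy_incidence = torch.tensor(\n",
    "    [[1.0, 1.0, 0.0],\n",
    "     [1.0, 1.0, 1.0],\n",
    "     [0.0, 1.0, 1.0]]\n",
    ").to_sparse()\n",
    "\n",
    "# Hypothetical node features: one row per node, two features each\n",
    "toy_x_0 = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])\n",
    "\n",
    "# Sparse-dense product, the same call used in the backbone's forward pass\n",
    "toy_x_1 = torch.sparse.mm(toy_incidence, toy_x_0)\n",
    "print(toy_x_1)  # rows: [1, 1], [3, 3], [2, 3]"
   ]
  },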
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class myModel(pl.LightningModule):\n",
    "    \"\"\"Toy hypergraph backbone: one linear layer for nodes, one for hyperedges.\"\"\"\n",
    "\n",
    "    def __init__(self, dim_hidden):\n",
    "        super().__init__()\n",
    "        self.dim_hidden = dim_hidden\n",
    "        self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)  # node features\n",
    "        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)  # hyperedge features\n",
    "\n",
    "    def forward(self, batch):\n",
    "        x_0 = batch.x_0\n",
    "        incidence_hyperedges = batch.incidence_hyperedges\n",
    "        # Aggregate node features into hyperedge features: X_1 = B_1 X_0\n",
    "        x_1 = torch.sparse.mm(incidence_hyperedges, x_0)\n",
    "\n",
    "        # Linear layer followed by ReLU, separately for nodes and hyperedges\n",
    "        x_0 = self.linear_0(x_0)\n",
    "        x_0 = torch.relu(x_0)\n",
    "        x_1 = self.linear_1(x_1)\n",
    "        x_1 = torch.relu(x_1)\n",
    "\n",
    "        # Return labels, batch assignment and the updated features in a dictionary\n",
    "        model_out = {\"labels\": batch.y, \"batch_0\": batch.batch_0}\n",
    "        model_out[\"x_0\"] = x_0\n",
    "        model_out[\"hyperedge\"] = x_1\n",
    "        return model_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Model initialization <a class=\"anchor\" id=\"sec5\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that the backbone is defined, we can create the `TBModel`, which takes care of everything else needed to train it.\n",
    "\n",
    "First we instantiate the remaining components: the readout, the loss, the feature encoder, the evaluator, and the optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "backbone = myModel(dim_hidden)\n",
    "\n",
    "readout = PropagateSignalDown(**readout_config)\n",
    "loss = TBLoss(**loss_config)\n",
    "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden)\n",
    "\n",
    "evaluator = TBEvaluator(**evaluator_config)\n",
    "optimizer = TBOptimizer(**optimizer_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can instantiate the TBModel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TBModel(backbone=backbone,\n",
    "                 backbone_wrapper=None,\n",
    "                 readout=readout,\n",
    "                 loss=loss,\n",
    "                 feature_encoder=feature_encoder,\n",
    "                 evaluator=evaluator,\n",
    "                 optimizer=optimizer,\n",
    "                 compile=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TBModel(backbone=myModel(\n",
       "  (linear_0): Linear(in_features=16, out_features=16, bias=True)\n",
       "  (linear_1): Linear(in_features=16, out_features=16, bias=True)\n",
       "), readout=PropagateSignalDown(num_cell_dimensions=0, self.hidden_dim=16, readout_name=PropagateSignalDown, loss=TBLoss(losses=[DatasetLoss(task=classification, loss_type=cross_entropy)]), feature_encoder=AllCellFeatureEncoder(in_channels=[7], out_channels=16, dimensions=range(0, 1)))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Training <a class=\"anchor\" id=\"sec6\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can use the `lightning` trainer to train the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
      "\n",
      "  | Name            | Type                  | Params | Mode \n",
      "------------------------------------------------------------------\n",
      "0 | feature_encoder | AllCellFeatureEncoder | 448    | train\n",
      "1 | backbone        | myModel               | 544    | train\n",
      "2 | readout         | PropagateSignalDown   | 34     | train\n",
      "3 | val_acc_best    | MeanMetric            | 0      | train\n",
      "------------------------------------------------------------------\n",
      "1.0 K     Trainable params\n",
      "0         Non-trainable params\n",
      "1.0 K     Total params\n",
      "0.004     Total estimated model params size (MB)\n",
      "13        Modules in train mode\n",
      "0         Modules in eval mode\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassPrecision was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassRecall was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=50` reached.\n"
     ]
    }
   ],
   "source": [
    "# Increase the number of epochs to get better results\n",
    "trainer = pl.Trainer(max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False, log_every_n_steps=1)\n",
    "\n",
    "trainer.fit(model, datamodule)\n",
    "train_metrics = trainer.callback_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Training metrics\n",
      " --------------------------\n",
      "train/accuracy:       0.7979\n",
      "train/precision:      0.8114\n",
      "train/recall:         0.7181\n",
      "val/loss:             0.5168\n",
      "val/accuracy:         0.7234\n",
      "val/precision:        0.6992\n",
      "val/recall:           0.6021\n",
      "train/loss:           0.4556\n"
     ]
    }
   ],
   "source": [
    "print('      Training metrics\\n', '-'*26)\n",
    "for key in train_metrics:\n",
    "    print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Testing the model <a class=\"anchor\" id=\"sec7\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can test the model and obtain the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test/accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7446808218955994     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.48221755027770996    </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test/precision       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7910714149475098     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test/recall        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6598039269447327     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test/accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7446808218955994    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.48221755027770996   \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test/precision      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7910714149475098    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test/recall       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6598039269447327    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.test(model, datamodule)\n",
    "test_metrics = trainer.callback_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Testing metrics\n",
      " -------------------------\n",
      "test/loss:           0.4822\n",
      "test/accuracy:       0.7447\n",
      "test/precision:      0.7911\n",
      "test/recall:         0.6598\n"
     ]
    }
   ],
   "source": [
    "print('      Testing metrics\\n', '-'*25)\n",
    "for key in test_metrics:\n",
    "    print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "topobench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
